{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "04a93a5a-3c24-44c8-a9d6-c7b29fb86607",
   "metadata": {},
   "source": [
    "Uses a prompted BERT to learn a metric between passage and query with the goal of solving the passage retrieval task in MS MARCO https://microsoft.github.io/msmarco/\n",
    "\n",
    "This approach was inspired by https://aclanthology.org/D19-1352.pdf and their usage of something similar to a prompted BERT to perform passage retrieval, we experiment with adding a prompt and taking input at the mask token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d2566239-c09d-4414-85ab-0d50a15a924a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "from transformers import BertModel, BertTokenizer\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81c0a87e-7bb8-4fa8-b8c4-d5b977401d8e",
   "metadata": {},
   "source": [
    "# Data Exploration\n",
    "The dataset is massive, near 26GB. Note we use the triples version of the dataset, designed to be convenient for training, which provides triples of the form (query, relevant passage, non relevant passage)"
   ]
  },
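  {
   "cell_type": "markdown",
   "id": "sketch-chunked-read",
   "metadata": {},
   "source": [
    "Since the full file is so large, one option beyond `nrows` is to stream it in chunks. The following is only a minimal sketch: the in-memory sample stands in for the real TSV, and the chunk size of 2 is an arbitrary illustrative value.\n",
    "\n",
    "```python\n",
    "import io\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical three-row stand-in for the tab-separated triples file\n",
    "sample = io.StringIO('q1\\tp1\\tn1\\nq2\\tp2\\tn2\\nq3\\tp3\\tn3\\n')\n",
    "\n",
    "# chunksize makes read_csv return an iterator of DataFrames instead of one big frame\n",
    "reader = pd.read_csv(sample, sep='\\t', names=['query', 'positive', 'negative'], chunksize=2)\n",
    "n_rows = sum(len(chunk) for chunk in reader)\n",
    "print(n_rows)\n",
    "```\n",
    "\n",
    "With the real file, each chunk could be processed and discarded in turn so that only one chunk is held in memory at a time."
   ]
  },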
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "13158235-9e85-4400-8ad8-08cf6b33b951",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>query</th>\n",
       "      <th>positive</th>\n",
       "      <th>negative</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>is a little caffeine ok during pregnancy</td>\n",
       "      <td>We donât know a lot about the effects of caf...</td>\n",
       "      <td>It is generally safe for pregnant women to eat...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>what fruit is native to australia</td>\n",
       "      <td>Passiflora herbertiana. A rare passion fruit n...</td>\n",
       "      <td>The kola nut is the fruit of the kola tree, a ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>how large is the canadian military</td>\n",
       "      <td>The Canadian Armed Forces. 1  The first large-...</td>\n",
       "      <td>The Canadian Physician Health Institute (CPHI)...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>types of fruit trees</td>\n",
       "      <td>Cherry. Cherry trees are found throughout the ...</td>\n",
       "      <td>The kola nut is the fruit of the kola tree, a ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>how many calories a day are lost breastfeeding</td>\n",
       "      <td>Not only is breastfeeding better for the baby,...</td>\n",
       "      <td>However, you still need some niacin each day; ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>what does petsmart pay part time</td>\n",
       "      <td>Petsmart-About $7 (for a grooming bather or ca...</td>\n",
       "      <td>What is double-time pay? Double-time pay is a ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>what does dystopian mean</td>\n",
       "      <td>Top definition. dystopian. A dystopia (alterna...</td>\n",
       "      <td>(film) Equilibrium is a 2002 American dystopia...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>when is a partnership return due</td>\n",
       "      <td>1 Partnership and S Corporation tax returns wi...</td>\n",
       "      <td>â¢ With decreased stroke volume, due to decre...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>what type of cable is used between switches ?</td>\n",
       "      <td>No thats not possible. Only if you have auto-m...</td>\n",
       "      <td>1-16 of 1,177 results for verizon network type...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>what type of charge is trespassing</td>\n",
       "      <td>If the criminal trespassing charge is a less s...</td>\n",
       "      <td>1-16 of 1,177 results for verizon network type...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10000 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               query  \\\n",
       "0           is a little caffeine ok during pregnancy   \n",
       "1                  what fruit is native to australia   \n",
       "2                 how large is the canadian military   \n",
       "3                               types of fruit trees   \n",
       "4     how many calories a day are lost breastfeeding   \n",
       "...                                              ...   \n",
       "9995                what does petsmart pay part time   \n",
       "9996                        what does dystopian mean   \n",
       "9997                when is a partnership return due   \n",
       "9998   what type of cable is used between switches ?   \n",
       "9999              what type of charge is trespassing   \n",
       "\n",
       "                                               positive  \\\n",
       "0     We donât know a lot about the effects of caf...   \n",
       "1     Passiflora herbertiana. A rare passion fruit n...   \n",
       "2     The Canadian Armed Forces. 1  The first large-...   \n",
       "3     Cherry. Cherry trees are found throughout the ...   \n",
       "4     Not only is breastfeeding better for the baby,...   \n",
       "...                                                 ...   \n",
       "9995  Petsmart-About $7 (for a grooming bather or ca...   \n",
       "9996  Top definition. dystopian. A dystopia (alterna...   \n",
       "9997  1 Partnership and S Corporation tax returns wi...   \n",
       "9998  No thats not possible. Only if you have auto-m...   \n",
       "9999  If the criminal trespassing charge is a less s...   \n",
       "\n",
       "                                               negative  \n",
       "0     It is generally safe for pregnant women to eat...  \n",
       "1     The kola nut is the fruit of the kola tree, a ...  \n",
       "2     The Canadian Physician Health Institute (CPHI)...  \n",
       "3     The kola nut is the fruit of the kola tree, a ...  \n",
       "4     However, you still need some niacin each day; ...  \n",
       "...                                                 ...  \n",
       "9995  What is double-time pay? Double-time pay is a ...  \n",
       "9996  (film) Equilibrium is a 2002 American dystopia...  \n",
       "9997  â¢ With decreased stroke volume, due to decre...  \n",
       "9998  1-16 of 1,177 results for verizon network type...  \n",
       "9999  1-16 of 1,177 results for verizon network type...  \n",
       "\n",
       "[10000 rows x 3 columns]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('Datasets/triples.train.small.tsv', sep='\\t', nrows=10000, # only read a small sample of the data for now \n",
    "                 names=[\"query\", \"positive\", \"negative\"], encoding=\"utf-8\") \n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f397cbb1-9ea1-4248-83d6-35db306d2eb5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'We donâ\\x80\\x99t know a lot about the effects of caffeine during pregnancy on you and your baby. So itâ\\x80\\x99s best to limit the amount you get each day. If youâ\\x80\\x99re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1Â½ 8-ounce cups of coffee or one 12-ounce cup of coffee.'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.iloc[0][\"positive\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5d0ffe25-f6e1-4cb9-9226-4b2fef3ed859",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAQ3UlEQVR4nO3dbYxcV33H8e+vCYEKKCbEtSzb6YZiFUWV8qBVGgRClAiah6pOJYiCKmIiV34TKpBaFVNelEq8CJVKGiQUySVpHURJIh4UC1IKNSDUFwnYEEJIgJjUUWw5sYEQaBHQwL8v5jgZtvswuzu74z37/UijOffcM7vn6K5+c+bcO3dTVUiS+vIbk+6AJGn8DHdJ6pDhLkkdMtwlqUOGuyR16MxJdwDgnHPOqampqUl3Q5LWlEOHDn2/qjbOtu+0CPepqSkOHjw46W5I0pqS5LG59rksI0kdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjo0Urgn2ZDk40m+neThJK9KcnaSzyd5pD2/tLVNkg8mOZzkgSQXr+wQJEkzjTpzvxn4bFW9ErgAeBjYAxyoqu3AgbYNcAWwvT12A7eMtceSpAUtGO5JXgK8FrgVoKp+UVU/AnYA+1qzfcDVrbwDuL0G7gU2JNk85n5LkuYxyjdUzwNOAv+c5ALgEPAOYFNVHW9tngA2tfIW4PGh1x9tdceH6kiym8HMnnPPPXep/V+XpvZ85tnykRuvmmBPJJ2uRlmWORO4GLilqi4C/ofnlmAAqMG/c1rUv3Sqqr1VNV1V0xs3znprBEnSEo0S7keBo1V1X9v+OIOwf/LUckt7PtH2HwO2Db1+a6uTJK2SBcO9qp4AHk/ye63qMuAhYD+ws9XtBO5u5f3Ade2qmUuBp4eWbyRJq2DUu0L+BfDRJGcBjwLXM3hjuCvJLuAx4JrW9h7gSuAw8NPWVpK0ikYK96q6H5ieZddls7Qt4IbldUszDZ9ElaSF+A1VSeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHRv1PTFolw/+U48iNV02wJ5LWMmfuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUoZG+oZrkCPAT4JfAM1U1neRs4E5gCjgCXFNVTyUJcDNwJfBT4G1V9bXxd10z+e1WSacsZub+h1V1YVVNt+09wIGq2g4caNsAVwDb22M3cMu4OitJGs1ylmV2APtaeR9w9VD97TVwL7AhyeZl/B5J0iKNGu4FfC7JoSS7W92mqjreyk8Am1p5C/D40GuPtrpfk2R3koNJDp48eXIJXZckzWXUu0K+pqqOJflt4PNJvj28s6oqSS3mF1fVXmAvwPT09KJeK0ma30gz96o61p5PAJ8CLgGePLXc0p5PtObHgG1DL9/a6iRJq2TBcE/ywiQvPlUG3gg8COwHdrZmO4G7W3k/cF0GLgWeHlq+kSStglGWZTYBnxpc4ciZwL9W1WeTfBW4K8ku4DHgmtb+HgaXQR5mcCnk9WPvdWeGL2GUpHFYMNyr6lHgglnqfwBcNkt9ATeMpXeSpCXxG6qS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1KGRwz3JGUm+nuTTbfu8JPclOZzkziRntfrnt+3Dbf/UCvVdkjSHxczc3wE8PLT9fuCmqnoF8BSwq9XvAp5q9Te1dpKkVTRSuCfZClwFfLhtB3g98PHWZB9wdSvvaNu0/Ze19pKkVTLqzP0fgb8GftW2Xwb8qKqeadtHgS2tvAV4HKDtf7q1/zVJdic5mOTgyZMnl9Z7
SdKszlyoQZI/Bk5U1aEkrxvXL66qvcBegOnp6RrXz9XA1J7PPFs+cuNVE+yJpElYMNyBVwN/kuRK4AXAbwE3AxuSnNlm51uBY639MWAbcDTJmcBLgB+MveeSpDktuCxTVe+uqq1VNQVcC3yhqv4M+CLwptZsJ3B3K+9v27T9X6gqZ+aStIpGmbnP5V3AHUneB3wduLXV3wp8JMlh4IcM3hC0BMNLK5K0GIsK96r6EvClVn4UuGSWNj8D3jyGvkmSlmg5M3ctkic5Ja0Wbz8gSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkd8huqa5z3n5E0G2fuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtSh7y3zArz3i+SJsGZuyR1yJn7OjP8SeLIjVdNsCeSVtKCM/ckL0jylSTfSPKtJH/X6s9Lcl+Sw0nuTHJWq39+2z7c9k+t8BgkSTOMsizzc+D1VXUBcCFweZJLgfcDN1XVK4CngF2t/S7gqVZ/U2snSVpFC4Z7Dfx323xeexTweuDjrX4fcHUr72jbtP2XJcm4OixJWthIa+5JzgAOAa8APgR8D/hRVT3TmhwFtrTyFuBxgKp6JsnTwMuA74+x32ueV9FIWkkjXS1TVb+sqguBrcAlwCuX+4uT7E5yMMnBkydPLvfHSZKGLOpSyKr6EfBF4FXAhiSnZv5bgWOtfAzYBtD2vwT4wSw/a29VTVfV9MaNG5fWe0nSrEa5WmZjkg2t/JvAG4CHGYT8m1qzncDdrby/bdP2f6Gqaox9liQtYJQ1983Avrbu/hvAXVX16SQPAXckeR/wdeDW1v5W4CNJDgM/BK5dgX5LkuaxYLhX1QPARbPUP8pg/X1m/c+AN4+ld5KkJfH2A5LUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUof8B9nrgP8YRFp/DPd1bDj0j9x41QR7ImncXJaRpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdcgvMWlWfsFJWtucuUtShwx3SeqQ4S5JHXLNXQty/V1aexYM9yTbgNuBTUABe6vq5iRnA3cCU8AR4JqqeipJgJuBK4GfAm+rqq+tTPdPHwagpNPJKMsyzwB/WVXnA5cCNyQ5H9gDHKiq7cCBtg1wBbC9PXYDt4y915KkeS0Y7lV1/NTMu6p+AjwMbAF2APtas33A1a28A7i9Bu4FNiTZPO6OS5LmtqgTqkmmgIuA+4BNVXW87XqCwbINDIL/8aGXHW11M3/W7iQHkxw8efLkYvstSZrHyOGe5EXAJ4B3VtWPh/dVVTFYjx9ZVe2tqumqmt64ceNiXipJWsBI4Z7keQyC/aNV9clW/eSp5Zb2fKLVHwO2Db18a6uTJK2SBcO9Xf1yK/BwVX1gaNd+YGcr7wTuHqq/LgOXAk8PLd9IklbBKNe5vxp4K/DNJPe3ur8BbgTuSrILeAy4pu27h8FlkIcZXAp5/Tg7LEla2ILhXlX/CWSO3ZfN0r6AG5bZL0nSMvgN1RUw/IUmSZoE7y0jSR1y5q4V5W0ZpMlw5i5JHXLmvgzrfW3dWbl0+jLcNXG+SUjjZ7hrLAxo6fTimrskdchwl6QOuSyjNcFlH2lxnLlLUoecueu04gxdGg/DXYDX7Eu9cVlGkjrkzF2rxiUXafU4c5ekDhnuktQhw12SOmS4S1KHPKGqsfOySmnynLlLUoecuasbXmopPceZuyR1yHCXpA4Z7pLUoQXDPcltSU4keXCo7uwkn0/ySHt+aatPkg8mOZzkgSQXr2TnJUmzG2Xm/i/A5TPq9gAHqmo7cKBtA1wBbG+P3cAt4+mm9JypPZ959iFpdgteLVNVX04yNaN6B/C6Vt4HfAl4V6u/vaoKuDfJhiSbq+r42HqsLhjM0spa6qWQm4YC+wlgUytvAR4fane0
1f2/cE+ym8HsnnPPPXeJ3dBqM5SltWHZ17lXVSWpJbxuL7AXYHp6etGvl+bjNe9a75Ya7k+eWm5Jshk40eqPAduG2m1tdWuaQSFprVnqpZD7gZ2tvBO4e6j+unbVzKXA0663S9LqW3DmnuRjDE6enpPkKPC3wI3AXUl2AY8B17Tm9wBXAoeBnwLXr0CfJUkLGOVqmbfMseuyWdoWcMNyO3U68MTh5HkMpKXzxmFa0xb7BuD5E60XhvsiOZuUtBYY7lq3nMWrZ944TJI65Mxd3XMpTeuR4S7N4HKNeuCyjCR1yHCXpA4Z7pLUIdfcJTzpqv44c5ekDhnuktQhl2WkecxcrvHSSK0VhvsQ1121VF4br9ONyzKS1CFn7tIiLHaG7oxek+LMXZI6tO5n7q6za6n829HpbN2Hu7Ra5lqicelGK8FlGUnq0LqbuftRWtJ64Mxdkjq0LmbuztYlrTfdhruBrkkZ5W9vsW080arF6jbcpZ4Y9Fosw106jSxnRu8bgIatSLgnuRy4GTgD+HBV3bgSv0fS7JbzBuCbRB9SVeP9gckZwHeBNwBHga8Cb6mqh+Z6zfT0dB08eHBJv8+1dWn55noDGKX9zNf4hrB6khyqqunZ9q3EzP0S4HBVPdp++R3ADmDOcJc0WYudJM3XfrFvDstpv9g3ksV+olnOjeLme81qvBmuxMz9TcDlVfXnbfutwB9U1dtntNsN7G6bvw88ONaOnH7OAb4/6U6sMMfYB8e4dvxOVW2cbcfETqhW1V5gL0CSg3N9tOiFY+yDY+zDehjjSnxD9RiwbWh7a6uTJK2SlQj3rwLbk5yX5CzgWmD/CvweSdIcxr4sU1XPJHk78O8MLoW8raq+tcDL9o67H6chx9gHx9iH7sc49hOqkqTJ866QktQhw12SOjTxcE9yeZLvJDmcZM+k+zMuSY4k+WaS+5McbHVnJ/l8kkfa80sn3c/FSHJbkhNJHhyqm3VMGfhgO64PJLl4cj0f3RxjfG+SY+1Y3p/kyqF9725j/E6SP5pMrxcnybYkX0zyUJJvJXlHq+/mWM4zxq6O5byqamIPBidcvwe8HDgL+AZw/iT7NMaxHQHOmVH398CeVt4DvH/S/VzkmF4LXAw8uNCYgCuBfwMCXArcN+n+L2OM7wX+apa257e/2ecD57W/5TMmPYYRxrgZuLiVX8zgdiHn93Qs5xljV8dyvsekZ+7P3qqgqn4BnLpVQa92APtaeR9w9eS6snhV9WXghzOq5xrTDuD2GrgX2JBk86p0dBnmGONcdgB3VNXPq+q/gMMM/qZPa1V1vKq+1so/AR4GttDRsZxnjHNZk8dyPpMO9y3A40PbR5n/AKwlBXwuyaF2qwWATVV1vJWfADZNpmtjNdeYeju2b29LErcNLaet+TEmmQIuAu6j02M5Y4zQ6bGcadLh3rPXVNXFwBXADUleO7yzBp8Fu7oOtccxNbcAvwtcCBwH/mGivRmTJC8CPgG8s6p+PLyvl2M5yxi7PJazmXS4d3urgqo61p5PAJ9i8BHvyVMfZ9vzicn1cGzmGlM3x7aqnqyqX1bVr4B/4rmP62t2jEmexyD0PlpVn2zVXR3L2cbY47Gcy6TDvctbFSR5YZIXnyoDb2Rw18v9wM7WbCdw92R6OFZzjWk/cF270uJS4Omhj/xryoz15T/luTuY7geuTfL8JOcB24GvrHb/FitJgFuBh6vqA0O7ujmWc42xt2M5r0mf0WVwJv67DM5Ov2fS/RnTmF7O4Mz7N4BvnRoX8DLgAPAI8B/A2ZPu6yLH9TEGH2X/l8Ga5K65xsTgyooPteP6TWB60v1fxhg/0sbwAIMQ2DzU/j1tjN8Brph0/0cc42sYLLk8ANzfHlf2dCznGWNXx3K+h7cfkKQOTXpZRpK0Agx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1KH/A4NgYiTivsHqAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# plot a passage length histogram\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "lengths = df[\"positive\"].apply(lambda x: len(tokenizer.encode(x)))\n",
    "plt.hist(lengths, bins=100)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f67f9755-cb10-4b2e-8f3d-e4117bfa1e14",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length=256 # from the plot we can see a batch size of 256 is more than enough to represent most passages"
   ]
  },
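  {
   "cell_type": "markdown",
   "id": "sketch-length-coverage",
   "metadata": {},
   "source": [
    "The histogram can be complemented with a direct coverage number. This is a small sketch only: the token counts below are made up and stand in for the `lengths` series computed above.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical token counts, standing in for the real `lengths` series\n",
    "lengths = np.array([40, 55, 60, 72, 80, 95, 110, 130, 180, 300])\n",
    "max_length = 256\n",
    "\n",
    "# Fraction of passages that would fit without truncation\n",
    "coverage = (lengths <= max_length).mean()\n",
    "print(f'{coverage:.0%} of passages fit within {max_length} tokens')\n",
    "```\n",
    "\n",
    "Note that the model input later also contains the query and the prompt, so the budget left for the passage itself is somewhat smaller than `max_length`."
   ]
  },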
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c5f9d476-6425-49e6-b1d5-4074cd995556",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a train test split by first creating a dataset with query, sample and label. Then splitting it\n",
    "positive = df[[\"query\", \"positive\"]].rename(columns={\"positive\": \"passage\"})\n",
    "positive[\"label\"] = 1\n",
    "negative = df[[\"query\", \"negative\"]].rename(columns={\"negative\": \"passage\"})\n",
    "negative[\"label\"] = 0\n",
    "dataset = pd.concat([positive, negative], axis=0).reset_index(drop=True)\n",
    "dataset = dataset.sample(20000, replace=False)\n",
    "X_train, X_test, y_train, y_test = train_test_split(dataset[[\"query\", \"passage\"]], np.array(dataset[\"label\"]), test_size=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "19502ce5-9ec3-4c12-b92a-bd9b67176c6f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>query</th>\n",
       "      <th>passage</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>4769</th>\n",
       "      <td>how many bundles of shingles in a square</td>\n",
       "      <td>The most common type of shingles, known as str...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7968</th>\n",
       "      <td>what is the name of the turkey neck</td>\n",
       "      <td>A turkey neck is the slang term that has been ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6678</th>\n",
       "      <td>how long do you microwave a potato to cook it</td>\n",
       "      <td>1 If you prefer crispy skinned sweet potatoes,...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13510</th>\n",
       "      <td>who pays for kidney transplant</td>\n",
       "      <td>Ciprofloxacin may cause swelling or tearing of...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2868</th>\n",
       "      <td>when was a man called ove released</td>\n",
       "      <td>A Man Called Ove (Swedish: En man som heter Ov...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11128</th>\n",
       "      <td>what is appropriate in your field dress</td>\n",
       "      <td>Dress. When you create a dress code policy for...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9089</th>\n",
       "      <td>where is stow ohio located</td>\n",
       "      <td>Stow is a city in Summit County, Ohio, United ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8642</th>\n",
       "      <td>how many species of deer</td>\n",
       "      <td>Itâs thought there are more than two million...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6610</th>\n",
       "      <td>define skeleton</td>\n",
       "      <td>skeleton. 1  the hard framework of an animal b...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19402</th>\n",
       "      <td>how much to tip peapod driver</td>\n",
       "      <td>Easy to-go hummus #recipe for truck drivers or...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>16000 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                               query  \\\n",
       "4769        how many bundles of shingles in a square   \n",
       "7968             what is the name of the turkey neck   \n",
       "6678   how long do you microwave a potato to cook it   \n",
       "13510                 who pays for kidney transplant   \n",
       "2868              when was a man called ove released   \n",
       "...                                              ...   \n",
       "11128        what is appropriate in your field dress   \n",
       "9089                      where is stow ohio located   \n",
       "8642                        how many species of deer   \n",
       "6610                                 define skeleton   \n",
       "19402                  how much to tip peapod driver   \n",
       "\n",
       "                                                 passage  \n",
       "4769   The most common type of shingles, known as str...  \n",
       "7968   A turkey neck is the slang term that has been ...  \n",
       "6678   1 If you prefer crispy skinned sweet potatoes,...  \n",
       "13510  Ciprofloxacin may cause swelling or tearing of...  \n",
       "2868   A Man Called Ove (Swedish: En man som heter Ov...  \n",
       "...                                                  ...  \n",
       "11128  Dress. When you create a dress code policy for...  \n",
       "9089   Stow is a city in Summit County, Ohio, United ...  \n",
       "8642   Itâs thought there are more than two million...  \n",
       "6610   skeleton. 1  the hard framework of an animal b...  \n",
       "19402  Easy to-go hummus #recipe for truck drivers or...  \n",
       "\n",
       "[16000 rows x 2 columns]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c8cf4451-7484-48ea-8e5f-a1a36fea17e9",
   "metadata": {},
   "source": [
    "# Modelling\n",
    "Define the dataset class and a BERT_prompt class, with prompting logic in the dataset for speed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b5809a0e-71ab-4935-ba97-755083ca1e59",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "bert = BertModel.from_pretrained('bert-base-uncased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "28db62b9-4940-463f-86bf-dca165c6435e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[101, 7592, 103, 102]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.encode(\"Hello \" + tokenizer.mask_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6619ef60-6bff-42b7-866b-27ed2fd1e2f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a dataset class for prompted BERT fine tuning\n",
    "# The dataset takes in query and passage, and then construct a training sample as: <query> + <prompt> + [MASK] + <passage>, returning the position of the mask\n",
    "# It also performs padding and stores the labels\n",
    "class BERTPromptDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, dataset, labels, tokenizer, prompt=\"Is the following text relevant?\", max_length=max_length):\n",
    "        # Construct the input sentence\n",
    "        input_sentences = [\"{}. {} {} {}\".format(query, prompt, tokenizer.mask_token, passage) for _, (query, passage) in dataset.iterrows()]\n",
    "        \n",
    "        # Encode and store\n",
    "        encodings_dict = tokenizer.batch_encode_plus(input_sentences, truncation=True, max_length=max_length, padding=\"max_length\")\n",
    "        self.input_ids = encodings_dict['input_ids']\n",
    "        self.attn_masks = encodings_dict['attention_mask']\n",
    "        self.labels = labels\n",
    "\n",
    "        # Calculate the position of the mask using self.input_ids\n",
    "        mask_id = tokenizer.encode(tokenizer.mask_token)[1] # 103\n",
    "        self.mask_pos = [sent_ids.index(mask_id) for sent_ids in self.input_ids]\n",
    "        \n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.input_ids)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return_dict = {\"input_ids\": torch.tensor(self.input_ids[idx]),\n",
    "                       \"attention_mask\": torch.tensor(self.attn_masks[idx]), \n",
    "                       \"mask_pos\": torch.tensor(self.mask_pos[idx]),\n",
    "                       \"labels\": torch.tensor(self.labels[idx]).float()} \n",
    "        return return_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8eda4e54-2b13-4c15-9e73-e88b0dfac008",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'input_ids': tensor([[ 101, 2129, 2116,  ...,    0,    0,    0],\n",
       "         [ 101, 2054, 2003,  ...,    0,    0,    0],\n",
       "         [ 101, 2129, 2146,  ...,    0,    0,    0],\n",
       "         ...,\n",
       "         [ 101, 2054, 2193,  ...,    0,    0,    0],\n",
       "         [ 101, 6032, 3465,  ...,    0,    0,    0],\n",
       "         [ 101, 2129, 2000,  ...,    0,    0,    0]]),\n",
       " 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         ...,\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0],\n",
       "         [1, 1, 1,  ..., 0, 0, 0]]),\n",
       " 'mask_pos': tensor([17, 16, 18, 13, 16, 21, 13, 12, 14, 16]),\n",
       " 'labels': tensor([1., 1., 1., 0., 1., 0., 0., 0., 1., 1.])}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset = BERTPromptDataset(X_train, y_train, tokenizer)\n",
    "dataset[0:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "742873b5-416c-45ed-8e67-3a5b14c41b74",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class that expects a prompted input from the BERTPromptDataset\n",
    "# Takes the input, forward propagates it through BERT and concatenates the output at the [MASK] and [CLS] token to get a representation of the text\n",
    "# This is then passed to a linear head to perform binary classification for how relevant it is\n",
    "class BERTPrompt(torch.nn.Module):\n",
    "    def __init__(self, bert, tokenizer):\n",
    "        super().__init__()\n",
    "        self.bert = bert.cuda()\n",
    "        self.tokenizer = tokenizer\n",
    "        self.linear = torch.nn.Linear(768*2, 1).cuda()\n",
    "        self.act = torch.nn.Sigmoid()\n",
    "        \n",
    "    # input_dict is obtained through indexing the dataset e.g. dataset[0:10]\n",
    "    def forward(self, input_dict):\n",
    "        output = self.bert(input_dict[\"input_ids\"].cuda(), attention_mask=input_dict['attention_mask'].cuda())[0] # output is of shape [10, 256, 768]\n",
    "        cls_out = output[:, 0, :]\n",
    "        mask_out = output[torch.arange(cls_out.shape[0]).cuda(), input_dict[\"mask_pos\"].cuda(), :] # indexing like [[0, 1], [13, 14]] will select items [[0, 13], [1, 14]]\n",
    "        \n",
    "        representation = torch.cat([cls_out, mask_out], dim=1)\n",
    "        logit = self.linear(representation)\n",
    "        return logit\n",
    "    \n",
    "    # return the logit with sigmoid activation applied\n",
    "    def predict(self, input_dict):\n",
    "        return self.act(self.forward(input_dict))"
   ]
  },
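  {
   "cell_type": "markdown",
   "id": "sketch-mask-gather",
   "metadata": {},
   "source": [
    "The indexing used to pull out the hidden state at each example's `[MASK]` position can be checked on a toy tensor. This is only an illustrative sketch: the shapes and positions are made up and far smaller than the real BERT output.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "batch, seq_len, hidden = 2, 4, 3\n",
    "output = torch.arange(batch * seq_len * hidden).reshape(batch, seq_len, hidden)\n",
    "mask_pos = torch.tensor([1, 3])  # one hypothetical [MASK] position per example\n",
    "\n",
    "# Pairing row indices [0, 1] with positions [1, 3] picks output[0, 1] and output[1, 3]\n",
    "mask_out = output[torch.arange(batch), mask_pos, :]\n",
    "\n",
    "assert torch.equal(mask_out[0], output[0, 1])\n",
    "assert torch.equal(mask_out[1], output[1, 3])\n",
    "print(mask_out.shape)  # torch.Size([2, 3])\n",
    "```\n",
    "\n",
    "The result has one row per example, which is what allows it to be concatenated with the `[CLS]` output along `dim=1`."
   ]
  },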
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "b26ff588-637e-44d0-8f65-e57f966efc38",
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_prompt = BERTPrompt(bert, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "01658e0b-45b4-451a-863d-ade62d1d6156",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2016],\n",
       "        [ 0.1476],\n",
       "        [ 0.1437],\n",
       "        [ 0.2560],\n",
       "        [-0.0234],\n",
       "        [ 0.1067],\n",
       "        [-0.0392],\n",
       "        [ 0.0778],\n",
       "        [ 0.0424],\n",
       "        [ 0.0211]], device='cuda:0')"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    output = bert_prompt.forward(dataset[:10])\n",
    "output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf90a884-0f1a-462e-aca3-faf0cb6f9d78",
   "metadata": {},
   "source": [
    "# Training\n",
    "Train with the entire dataset. Few shot learning is not that much of a problem in document retrieval, more adaptation and generalization to queries that are often unique"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "5c30ac00-1996-4c0d-ab38-ca3c5228304c",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "d1ba78f1-2a6c-4bd6-a8ad-fc05de5dc3e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_accuracy(model, test_dataset):\n",
    "    model.eval()\n",
    "    preds = []\n",
    "    labels = []\n",
    "    with torch.no_grad():\n",
    "        for batch_i in range(int(len(test_dataset)/batch_size)):\n",
    "            batch = test_dataset[batch_i*batch_size: (batch_i+1)*batch_size]\n",
    "            preds += list(bert_prompt.predict(batch)[:, 0].cpu().numpy())\n",
    "            labels += list(batch[\"labels\"].cpu().numpy())\n",
    "    return accuracy_score(np.array(preds) > 0.5,  np.array(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "77126638-542a-4582-b53c-229dc0b04907",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = BERTPromptDataset(X_train[:1000], y_train[:1000], tokenizer)\n",
    "test_dataset = BERTPromptDataset(X_test[:1000], y_test[:1000], tokenizer)\n",
    "optimizer = torch.optim.Adam(bert_prompt.parameters(), lr=1e-5, betas=(0.95, 0.9995))\n",
    "loss_func = torch.nn.BCELoss()\n",
    "epochs = 1\n",
    "batch_size = 16\n",
    "bert_prompt = BERTPrompt(bert, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "c77805bd-95cf-4b2c-9b4d-d32a00576dd7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch 0 Train Loss: 0.8332217335700989\n",
      "Epoch 0 Train Loss: 0.023418264463543892 Validation Accuracy: 0.8790322580645161\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(epochs):  \n",
    "    bert_prompt.train()\n",
    "    for batch_i in range(int(len(train_dataset)/batch_size)): # just ignore the final partial batch\n",
    "        bert_prompt.zero_grad()\n",
    "        batch = train_dataset[batch_i*batch_size: (batch_i+1)*batch_size]\n",
    "        preds = bert_prompt.predict(batch) # gives outputs with sigmoid activation\n",
    "        loss = loss_func(preds[:, 0], batch[\"labels\"].cuda())\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if batch_i % 100 == 0:\n",
    "            print(\"Batch {} Train Loss: {}\".format(batch_i, loss))\n",
    "        \n",
    "    torch.cuda.empty_cache()\n",
    "    accuracy = calc_accuracy(bert_prompt, test_dataset)\n",
    "    print(\"Epoch {} Train Loss: {} Validation Accuracy: {}\".format(epoch, loss, accuracy))"
   ]
  },
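  {
   "cell_type": "markdown",
   "id": "sketch-bce-with-logits",
   "metadata": {},
   "source": [
    "The loop above applies a sigmoid and then `BCELoss`. A common alternative, which is not used in this notebook, is `torch.nn.BCEWithLogitsLoss`, which takes the raw logits from `forward` and is generally more numerically stable. The sketch below uses made-up logits and labels to compare the two on moderate values.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical logits and labels, standing in for one small batch\n",
    "logits = torch.tensor([2.0, -1.0, 0.5])\n",
    "labels = torch.tensor([1.0, 0.0, 1.0])\n",
    "\n",
    "# Sigmoid followed by BCELoss, as in the training loop\n",
    "loss_probs = torch.nn.BCELoss()(torch.sigmoid(logits), labels)\n",
    "# The same loss computed directly from the logits\n",
    "loss_logits = torch.nn.BCEWithLogitsLoss()(logits, labels)\n",
    "\n",
    "print(loss_probs.item(), loss_logits.item())\n",
    "```\n",
    "\n",
    "If this were adopted, the training loop would pass `bert_prompt.forward(batch)` to the loss, while `predict` would still be used for evaluation."
   ]
  },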
  {
   "cell_type": "markdown",
   "id": "68b20553-2d2e-4d5b-8581-b292f3fbb4e1",
   "metadata": {},
   "source": [
    "With prompt (1000 examples, 1 epoch): 0.879 accuracy\n",
    "\n",
    "Without prompt (1000 examples, 1 epoch): 0.865 accuracy\n",
    "\n",
    "This at least shows a better few shot learning potential, need to do more testing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53ab0faa-20d9-477a-871d-80e7f3681288",
   "metadata": {},
   "source": [
    "<code>Batch 0 Train Loss: 0.10379621386528015\n",
    "Epoch 0 Train Loss: 0.14899209141731262 Validation Accuracy: 0.8276209677419355\n",
    "Batch 0 Train Loss: 0.013592392206192017\n",
    "Epoch 1 Train Loss: 0.09803555905818939 Validation Accuracy: 0.8306451612903226\n",
    "Batch 0 Train Loss: 0.0046791210770606995\n",
    "Epoch 2 Train Loss: 0.02600925974547863 Validation Accuracy: 0.8215725806451613\n",
    "Batch 0 Train Loss: 0.002842813730239868\n",
    "Epoch 3 Train Loss: 0.03294769302010536 Validation Accuracy: 0.8004032258064516\n",
    "Batch 0 Train Loss: 0.0033466718159615993\n",
    "Epoch 4 Train Loss: 0.013205558061599731 Validation Accuracy: 0.8165322580645161\n",
    "Batch 0 Train Loss: 0.004724848549813032\n",
    "Epoch 5 Train Loss: 0.01365710236132145 Validation Accuracy: 0.8165322580645161\n",
    "Batch 0 Train Loss: 0.003022630698978901\n",
    "Epoch 6 Train Loss: 0.01209681760519743 Validation Accuracy: 0.8245967741935484\n",
    "Batch 0 Train Loss: 0.004604271147400141\n",
    "Epoch 7 Train Loss: 0.017215438187122345 Validation Accuracy: 0.8225806451612904\n",
    "Batch 0 Train Loss: 0.0008006490534171462\n",
    "Epoch 8 Train Loss: 0.019000805914402008 Validation Accuracy: 0.8195564516129032\n",
    "Batch 0 Train Loss: 0.0017620337894186378\n",
    "Epoch 9 Train Loss: 0.00176604138687253 Validation Accuracy: 0.8326612903225806</code>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
